from pathlib import Path
from typing import Tuple
import pickle
import argparse
import os

import jax
import jax.numpy as jnp
import yaml
from flax import serialization
from yaml import SafeLoader
import omegaconf

from analysis.data_utils.analysis_repertoire import AnalysisRepertoire
from analysis.data_utils.dot_dict import DotDict
from baselines.qdax.types import Genotype, RNGKey, Descriptor, Fitness
from main_dcg_me import FactoryDCGTask



class ReEvaluator:
  """Re-evaluates a genotype several times, each time with a different random key.

  The scoring function is vmapped over the random key only (genotype and goal
  are shared by all re-evaluations) and jit-compiled. The re-evaluation axis is
  inserted at position 1 of every output of the scoring function.
  """

  def __init__(self, scoring_fn, num_reevals):
    """Wrap `scoring_fn` for repeated evaluation.

    Args:
      scoring_fn: callable invoked as `scoring_fn(genotype, goal, random_key)`.
        `reeval` unpacks exactly four outputs from it and keeps the first two
        as fitness and descriptor.
      num_reevals: number of random keys (i.e. re-evaluations) used per call.
    """
    self._scoring_fn = scoring_fn
    # in_axes=(None, None, 0): only the random keys are batched.
    # out_axes=1: the re-evaluation axis comes right after the first output axis.
    self._reeval_scoring_fn = jax.jit(jax.vmap(self._scoring_fn, in_axes=(None, None, 0), out_axes=1))
    self._num_reevals = num_reevals

  @property
  def num_reevals(self):
    """Number of re-evaluations performed by `reeval`."""
    return self._num_reevals

  @staticmethod
  def add_dimension_to_pytree(pytree, dim=0):
    """Return a copy of `pytree` where every leaf gained a size-1 axis at `dim`."""
    # `jax.tree_util.tree_map` is the supported spelling; the top-level
    # `jax.tree_map` alias is deprecated and absent from recent JAX releases.
    return jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, dim), pytree)

  def reeval(self,
             genotype_single: Genotype,
             goal: Descriptor,
             random_key: RNGKey,
             add_dimension: bool = False):
    """Evaluate one genotype `num_reevals` times with independent random keys.

    Args:
      genotype_single: genotype pytree passed unchanged to every re-evaluation.
      goal: second argument forwarded to the scoring function.
      random_key: key split into `num_reevals` subkeys, one per re-evaluation.
      add_dimension: if True, a size-1 leading axis is first added to every
        leaf of `genotype_single`.

    Returns:
      `(fit, desc)`: the first two outputs of the vmapped scoring function,
      with the re-evaluation axis at position 1.
    """
    if add_dimension:
      genotype_single = self.add_dimension_to_pytree(genotype_single)

    subkeys = jax.random.split(random_key, num=self._num_reevals)

    fit, desc, _, _ = self._reeval_scoring_fn(genotype_single, goal, jnp.asarray(subkeys))
    return fit, desc

  def mean_reevals(self, genotype_single: Genotype, goal, random_key: RNGKey, add_dimension: bool = False):
    """Return fitness and descriptor averaged over the re-evaluation axis (axis 1).

    Arguments are the same as for `reeval`. The incoming key is split once and
    only the derived subkey is consumed by the re-evaluations.
    """
    random_key, subkey = jax.random.split(random_key)
    fit_reevals, desc_reevals = self.reeval(genotype_single, goal, random_key=subkey, add_dimension=add_dimension)
    return jnp.mean(fit_reevals, axis=1), jnp.mean(desc_reevals, axis=1)


def _evaluate_one_batch_centroids(scoring_fn, actor_gc_params, array_centroids, random_key) -> Tuple[Fitness, Descriptor]:
  random_key, subkey = jax.random.split(random_key)
  actor_gc_params = jax.tree_util.tree_map(lambda x: jnp.repeat(jnp.expand_dims(x, axis=0), array_centroids.shape[0], axis=0), actor_gc_params)
  fitnesses, descriptors, *_ = scoring_fn(actor_gc_params, array_centroids, random_key=subkey)
  return fitnesses, descriptors


def _reevaluate(scoring_fn, actor_gc_params, list_centroids, num_reevals, random_key) -> Tuple[Fitness, Descriptor]:
  """Evaluate the actor on all centroids `num_reevals` times and stack the results.

  Args:
    scoring_fn: scoring function forwarded to `_evaluate_one_batch_centroids`.
    actor_gc_params: pytree of actor parameters (not yet batched per centroid).
    list_centroids: centroids, converted to a single array with `jnp.asarray`.
    num_reevals: number of evaluation rounds; a fresh subkey is drawn per round.
    random_key: key from which the per-round subkeys are split.

  Returns:
    `(fitnesses, descriptors)`, each stacking the per-round results along a
    new axis 1.
  """
  import functools

  centroids_batch = jnp.asarray(list_centroids)
  # `scoring_fn` is bound with functools.partial, so it is not passed as an
  # argument to the jitted callable.
  jitted_evaluate = jax.jit(functools.partial(_evaluate_one_batch_centroids, scoring_fn))

  fitnesses_per_round = []
  descriptors_per_round = []
  for round_index in range(num_reevals):
    print(f"Reevaluating {round_index} / {num_reevals}")
    random_key, round_key = jax.random.split(random_key)
    round_fitnesses, round_descriptors, *_ = jitted_evaluate(actor_gc_params, centroids_batch, random_key=round_key)
    fitnesses_per_round.append(round_fitnesses)
    descriptors_per_round.append(round_descriptors)

  return jnp.stack(fitnesses_per_round, axis=1), jnp.stack(descriptors_per_round, axis=1)

def get_args():
  """Parse the command-line options of the re-evaluation script.

  Returns:
    An `argparse.Namespace` with the attributes `path_load`, `path_save`,
    `num_reevals` (10 when the option is omitted) and `resolution`.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument('-l', "--path-load", type=str, required=True,
                      help="Folder to load from.")
  parser.add_argument('-s', "--path-save", type=str, required=True,
                      help="Folder to save to.")
  # Without an explicit default the attribute would be None when the flag is
  # omitted; 10 mirrors the default of `reevaluate_saved_data_dcg`.
  parser.add_argument('-n', "--num-reevals", type=int, default=10,
                      help="Number of re-evaluations (default: %(default)s).")
  # `reevaluate_saved_data_dcg` rejects a missing resolution, so it is asked
  # for up front.
  parser.add_argument('-r', "--resolution", type=int, required=True,
                      help="Grid resolution, exposed as `args.resolution`.")

  return parser.parse_args()


def load_config(folder_load: str):
  """Load the Hydra configuration stored inside a run folder.

  Args:
    folder_load: folder that contains the `.hydra/config.yaml` file.

  Returns:
    Whatever `omegaconf.OmegaConf.load` returns for that file.
  """
  hydra_config_relpath = ".hydra/config.yaml"
  return omegaconf.OmegaConf.load(os.path.join(folder_load, hydra_config_relpath))


def load_actor_dcg_params(folder_load: str, init_params_gc):
  """Restore the saved actor parameters of a run.

  Args:
    folder_load: run folder containing `actor/actor.pickle`.
    init_params_gc: target pytree passed to `serialization.from_state_dict`
      together with the unpickled state.

  Returns:
    The result of `serialization.from_state_dict(init_params_gc, loaded_state)`.
  """
  path_actor_dcg_params = Path(folder_load) / "actor" / "actor.pickle"
  # NOTE: unpickling can execute arbitrary code; only load run folders you trust.
  # The context manager closes the file handle, including when loading fails.
  with open(path_actor_dcg_params, "rb") as file_actor:
    loaded_actor = pickle.load(file_actor)
  return serialization.from_state_dict(init_params_gc, loaded_actor)


def reevaluate_saved_data_dcg(folder_load: str, folder_save: str, num_reevals: int = 10, resolution=None):
  """Re-evaluate the saved actor of a run on a grid of centroids and save the result.

  Steps: load the run configuration, build the task from it, build a grid of
  centroids from the environment's `feat_space['vector']` bounds, restore the
  saved actor parameters, evaluate them `num_reevals` times on every centroid,
  and pickle the resulting `AnalysisRepertoire`.

  Args:
    folder_load: run folder holding the Hydra config and the saved actor.
    folder_save: output folder (created if missing); the repertoire is written
      to `analysis_repertoire.pkl` inside it.
    num_reevals: number of evaluation rounds.
    resolution: repeated once per dimension of `feat_space['vector']` to form
      the grid shape given to `compute_euclidean_centroids`. Required, despite
      the `None` default.

  Raises:
    ValueError: if `resolution` is None.
  """
  # Explicit check instead of `assert`, which is stripped when Python runs with -O.
  if resolution is None:
    raise ValueError("`resolution` must be provided to build the grid of centroids.")

  config = load_config(folder_load)
  random_key = jax.random.PRNGKey(config.seed)

  factory_task = FactoryDCGTask(config)

  random_key, subkey = jax.random.split(random_key)
  task_info = factory_task.get_task_info(subkey)

  env = task_info.env
  from baselines.qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids

  feat_space = env.feat_space['vector']
  # Same resolution along every dimension of the feature space.
  grid_shape = (resolution,) * feat_space.shape[0]
  centroids = compute_euclidean_centroids(grid_shape, minval=feat_space.low, maxval=feat_space.high)

  loaded_actor_gc_params = load_actor_dcg_params(folder_load, task_info.init_params_gc)

  fitnesses, descriptors = _reevaluate(task_info.scoring_actor_dc_fn, loaded_actor_gc_params, centroids, num_reevals, random_key)

  analysis_repertoire = AnalysisRepertoire.create(fitnesses, descriptors, centroids)

  path_save = Path(folder_save)
  path_save.mkdir(parents=True, exist_ok=True)

  path_save_repertoire = path_save / "analysis_repertoire.pkl"
  # The context manager flushes and closes the file deterministically.
  with open(path_save_repertoire, "wb") as file_repertoire:
    pickle.dump(analysis_repertoire, file_repertoire)


def main():
  """Script entry point: parse the command line and run the re-evaluation.

  Optional settings are forwarded only when they carry a value, so that the
  defaults of `reevaluate_saved_data_dcg` are not overridden with None.
  """
  args = get_args()

  optional_kwargs = {}
  if args.num_reevals is not None:
    optional_kwargs["num_reevals"] = args.num_reevals
  # Read defensively: forward the resolution only if the parsed namespace has one.
  resolution = getattr(args, "resolution", None)
  if resolution is not None:
    optional_kwargs["resolution"] = resolution

  reevaluate_saved_data_dcg(args.path_load, args.path_save, **optional_kwargs)


# Run the re-evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
  main()
